from typing import Any, Callable, Dict, Type, Optional, Union, List
from states import GraphState
import sqlalchemy as sql
import pandas as pd
from db import connection


def node_func_execute_agent_from_sql_connection(
    state: Any, 
    connection: Any, 
    code_snippet_key: str, 
    result_key: str,
    error_key: str,
    agent_function_name: str,
    post_processing: Optional[Callable[[Any], Any]] = None,
    error_message_prefix: str = "An error occurred during agent execution: "
) -> Dict[str, Any]:
    """
    Execute agent code retrieved from the state against a SQLAlchemy connection
    and return the result.

    The code snippet stored under ``code_snippet_key`` is executed with ``exec``
    to define a function named ``agent_function_name``. That function is then
    called with the connection as its only argument.

    Parameters
    ----------
    state : Any
        A state object that supports a ``get(key: str)`` method to retrieve values.
    connection : Any
        A SQLAlchemy ``Engine`` or an already-open connection object. If an
        ``Engine`` is given, a connection is opened from it for this call and
        closed again before returning. Any other object is passed through
        unchanged and left for the caller to manage.
    code_snippet_key : str
        The key in the state used to retrieve the Python code snippet defining the agent function.
    result_key : str
        The key in the returned dict under which the (post-processed) result is stored.
    error_key : str
        The key in the returned dict under which the error message (or None) is stored.
    agent_function_name : str
        The name of the function (e.g., 'sql_database_agent') expected to be defined in the code snippet.
    post_processing : Callable[[Any], Any], optional
        A function applied to the agent function's output before it is returned.
    error_message_prefix : str, optional
        A prefix prepended to the exception text when the agent function or the
        post-processing raises.

    Returns
    -------
    Dict[str, Any]
        ``{result_key: result_or_None, error_key: error_message_or_None}``.
        Exceptions raised by the agent function or by ``post_processing`` are
        caught and reported under ``error_key``.

    Raises
    ------
    ValueError
        If ``connection`` is None, or if the snippet does not define a callable
        named ``agent_function_name``.

    Notes
    -----
    The snippet is run with ``exec`` and has full interpreter access. It must
    come from a trusted source.
    """

    print("    * EXECUTING AGENT CODE ON SQL CONNECTION")

    # Ensure the connection object is provided
    if connection is None:
        raise ValueError("Connection object not found.")

    agent_code = state.get(code_snippet_key)

    # Execute the snippet in ONE namespace so that top-level imports and helper
    # definitions in the snippet are visible (as globals) to the agent function.
    # With separate globals/locals dicts, exec behaves like a class body and
    # functions defined in the snippet cannot see those names.
    namespace: Dict[str, Any] = {}
    exec(agent_code, namespace)

    # Retrieve the agent function from the executed code
    agent_function = namespace.get(agent_function_name, None)
    if agent_function is None or not callable(agent_function):
        raise ValueError(f"Agent function '{agent_function_name}' not found or not callable in the provided code.")

    # Only open (and later close) a connection ourselves when handed an Engine.
    # This is done after the code is resolved so that failures above cannot leak it.
    owns_connection = isinstance(connection, sql.engine.base.Engine)
    conn = connection.connect() if owns_connection else connection

    # Execute the agent function
    agent_error = None
    result = None
    try:
        result = agent_function(conn)

        # Apply post-processing while the connection is still open, in case the
        # raw result still depends on it.
        if post_processing is not None:
            result = post_processing(result)
    except Exception as e:
        print(e)
        agent_error = f"{error_message_prefix}{str(e)}"
    finally:
        if owns_connection:
            conn.close()

    # Return results
    return {result_key: result, error_key: agent_error}


def execute_sql_database_code(state: GraphState):
    """
    Run the SQL database agent function stored in ``state`` against the
    module-level ``connection`` imported from ``db``.

    The function code is read from ``state["sql_database_function"]`` and its
    name from ``state["sql_database_function_name"]``. A pandas DataFrame result
    is converted to a dict via ``DataFrame.to_dict()``. Other results are
    returned unchanged.

    Parameters
    ----------
    state : GraphState
        Graph state providing the agent code and function name via ``get``.

    Returns
    -------
    Dict[str, Any]
        ``{"data_sql": result_or_None, "sql_database_error": message_or_None}``.
    """

    # If the shared module-level object is an Engine, check out a connection
    # for this call only, and always return it so it is not leaked.
    is_engine = isinstance(connection, sql.engine.base.Engine)
    conn = connection.connect() if is_engine else connection

    try:
        return node_func_execute_agent_from_sql_connection(
            state=state,
            connection=conn,
            result_key="data_sql",
            error_key="sql_database_error",
            code_snippet_key="sql_database_function",
            agent_function_name=state.get("sql_database_function_name"),
            post_processing=lambda df: df.to_dict() if isinstance(df, pd.DataFrame) else df,
            error_message_prefix="An error occurred during executing the sql database pipeline: "
        )
    finally:
        if is_engine:
            conn.close()
